{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quick note on Transformer Block Unification\n",
    "\n",
    "A fun aside: the MLP blocks of a Transformer are essentially the same as its self-attention blocks, with a few changes:\n",
    "\n",
    "- the query projection is missing; the data itself is the query\n",
    "- the key and value are data-independent parameters (i.e. the MLP block is really just a cross-attention \"soft lookup\" into a fixed {key:value} table)\n",
    "- the Softmax (map/reduce non-linearity) is replaced with GeLU (map-only non-linearity)\n",
    "- the final Linear projection back to the residual pathway is missing (the values already have the residual width)\n",
    "\n",
    "This immediately suggests unifying these blocks into a more general Transformer \"superblock\" that is simply wired up to the residual pathway either in parallel (e.g. as with all the heads of a multi-headed self-attention) or in series (as is usually done from block to block). It also suggests in-between generalizations, e.g. multi-headed attention suggests the equivalent use of \"groups\" in Linear (or Conv) layers. Alternatively, attention could be done over two pools of nodes simultaneously: those whose keys and values are data-dependent and those whose keys and values are not, dispensing with the need for a distinction between the two block types.\n",
    "\n",
    "**TLDR**: A much simpler Transformer with a single type of block, wired up to a residual pathway both in parallel and in series, is possible, but to my knowledge it has not yet been convincingly achieved."
   ]
  },
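  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the map/reduce vs. map-only distinction from the list above. The tensor `scores` is a made-up stand-in for the query-key products and is not part of the original note. Nudging a single score changes every softmax output in that row, because softmax normalizes (reduces) across the row, while GeLU only changes the one matching entry."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "scores = torch.randn(2, 5) # hypothetical stand-in for q @ k.T: 2 queries, 5 keys\n",
    "bumped = scores.clone()\n",
    "bumped[0, 0] += 1.0 # perturb a single score\n",
    "\n",
    "# softmax reduces over the key dimension, so every entry of the first row changes\n",
    "soft_changed = torch.softmax(scores, dim=-1) != torch.softmax(bumped, dim=-1)\n",
    "# gelu is applied element by element, so only the perturbed entry changes\n",
    "gelu_changed = F.gelu(scores) != F.gelu(bumped)\n",
    "soft_changed[0].all(), gelu_changed.sum()"
   ]
  },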
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 512, 128])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# toy example\n",
    "torch.manual_seed(1337)\n",
    "B,T,C = 8,512,128\n",
    "n_head = 4\n",
    "n_embd = C\n",
    "x = torch.randn(B, T, C)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 512, 128])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# self-attention block of a single Head\n",
    "nh = C//n_head # head size, 128/4 = 32\n",
    "key = nn.Linear(C, nh, bias=False)\n",
    "query = nn.Linear(C, nh, bias=False)\n",
    "value = nn.Linear(C, nh, bias=False)\n",
    "proj = nn.Linear(nh, C, bias=False)\n",
    "k = key(x)\n",
    "q = query(x)\n",
    "v = value(x)\n",
    "# toy version: the usual 1/sqrt(head_size) scaling (and any masking) is left out\n",
    "att = torch.softmax(q @ k.transpose(-2,-1), dim=-1)\n",
    "y = att @ v\n",
    "r = proj(y) # standard self-attention blocks apply one more Linear on the way back to the residual pathway\n",
    "r.shape"
   ]
  },
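  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of why heads count as wired in parallel: concatenating the head outputs and applying one output projection gives the same result as giving each head its own slice of that projection and summing the contributions on the residual pathway. The helper name `compare_concat_vs_parallel` and the toy sizes are assumptions for illustration, and random tensors stand in for the per-head outputs `att @ v`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "def compare_concat_vs_parallel(n_head=4, hs=8, B=2, T=16):\n",
    "    # hypothetical helper; all sizes are toy values\n",
    "    C = n_head * hs\n",
    "    heads = [torch.randn(B, T, hs) for _ in range(n_head)] # stand-ins for each head's att @ v\n",
    "    out_proj = nn.Linear(C, C, bias=False)\n",
    "    # usual formulation: concatenate the heads, then apply one output projection\n",
    "    concat_out = out_proj(torch.cat(heads, dim=-1))\n",
    "    # same result: each head owns an (hs -> C) slice of the projection and the results are summed\n",
    "    W = out_proj.weight # (C, C), laid out as (out_features, in_features)\n",
    "    parallel_out = sum(h @ W[:, i*hs:(i+1)*hs].T for i, h in enumerate(heads))\n",
    "    return concat_out, parallel_out\n",
    "\n",
    "concat_out, parallel_out = compare_concat_vs_parallel()\n",
    "torch.allclose(concat_out, parallel_out, atol=1e-5)"
   ]
  },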
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 512, 128])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# typical MLP (feed-forward) block of a Transformer\n",
    "layer1 = nn.Linear(C, C*4, bias=False)\n",
    "layer2 = nn.Linear(C*4, C, bias=False)\n",
    "l1 = F.gelu(layer1(x))\n",
    "l2 = layer2(l1) # projects back to residual pathway\n",
    "l2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 512, 128])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the MLP block is actually attention over a fixed (not data-dependent) {k:v} dict\n",
    "q = x # change 1: query is simply the input\n",
    "k = layer1.weight # (4C, C): keys are data-independent learnable parameters\n",
    "v = layer2.weight.T # (4C, C): values are too, and they already map back to width C\n",
    "att = F.gelu(q @ k.transpose(-2,-1)) # change 2: using gelu instead of softmax\n",
    "y = att @ v\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.isclose(l2, y).all() # cool; isclose tolerates floating-point rounding, unlike =="
   ]
  }
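  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible reading of the two-pool idea from the note at the top, as a rough sketch rather than a tested design: a single head whose keys and values are the concatenation of a data-dependent pool (computed from `x`) and a fixed learned pool (an MLP-style table). The class name `TwoPoolHead`, the sizes, the 0.02 init scale, and the choice of a single scaled softmax over both pools with no causal mask are all assumptions; the note leaves the non-linearity open. Setting `n_fixed=0` recovers an ordinary (scaled, unmasked) self-attention head."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class TwoPoolHead(nn.Module):\n",
    "    # hypothetical name: one head attending over data-dependent and fixed key/value pools at once\n",
    "    def __init__(self, n_embd, head_size, n_fixed):\n",
    "        super().__init__()\n",
    "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
    "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
    "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
    "        self.proj = nn.Linear(head_size, n_embd, bias=False)\n",
    "        # second pool: data-independent keys and values, learned like MLP weights\n",
    "        self.fixed_k = nn.Parameter(torch.randn(n_fixed, head_size) * 0.02)\n",
    "        self.fixed_v = nn.Parameter(torch.randn(n_fixed, head_size) * 0.02)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, T, C = x.shape\n",
    "        q = self.query(x) # (B, T, hs)\n",
    "        k = torch.cat([self.key(x), self.fixed_k.expand(B, -1, -1)], dim=1) # (B, T + n_fixed, hs)\n",
    "        v = torch.cat([self.value(x), self.fixed_v.expand(B, -1, -1)], dim=1) # (B, T + n_fixed, hs)\n",
    "        # assumption: one scaled softmax over both pools, no causal mask\n",
    "        att = torch.softmax(q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5, dim=-1)\n",
    "        return self.proj(att @ v) # (B, T, C)\n",
    "\n",
    "head = TwoPoolHead(n_embd=32, head_size=8, n_fixed=128)\n",
    "out = head(torch.randn(2, 16, 32))\n",
    "out.shape"
   ]
  }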
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
